import pandas as pd
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
import plotly.offline as pyo
# Reading data: tab-separated ROI lookup table. The "Level5" column is dropped,
# the "modify*" columns are renamed to roi / level4..level1, and the columns are
# put in that order.
url = "https://raw.githubusercontent.com/bcaffo/MRIcloudT1volumetrics/master/inst/extdata/multilevel_lookup_table.txt"
multilevel_lookup = (
    pd.read_csv(url, sep="\t")
    .drop(columns=['Level5'])
    .rename(columns={
        "modify": "roi",
        "modify.1": "level4",
        "modify.2": "level3",
        "modify.3": "level2",
        "modify.4": "level1",
    })
    .loc[:, ['roi', 'level4', 'level3', 'level2', 'level1']]
)
#loading subject data
# Subject to plot. Named subject_id rather than `id` so the builtin is not shadowed.
subject_id = 127
subjectData = pd.read_csv(
    "https://raw.githubusercontent.com/smart-stats/ds4bio_book/main/book/assetts/kirby21AllLevels.csv"
).drop(['Unnamed: 0'], axis=1)
# Keep this subject's rows with type == 1 and level == 5; only ROI name and volume are needed.
subjectData = subjectData.loc[
    (subjectData.type == 1) & (subjectData.level == 5) & (subjectData.id == subject_id),
    ['roi', 'volume'],
]
if subjectData.empty:
    # Without this, an unknown id yields an empty frame and comp divides by a zero total.
    raise ValueError(f"no type-1, level-5 rows found for subject id {subject_id}")
## Merge the subject data with the multilevel data
# Inner join on "roi": ROIs absent from the lookup table are dropped.
subjectData = pd.merge(subjectData, multilevel_lookup, on="roi")
# Constant root label for the diagram, plus each ROI's fraction of the summed volume.
subjectData = subjectData.assign(icv="ICV")
subjectData = subjectData.assign(comp=subjectData.volume / subjectData.volume.sum())
#grouping
# Build Sankey links for a two-stage hierarchy: ICV -> level1 -> level2.
# Each link's value is the summed volume fraction ("comp") of the ROIs it carries,
# so what flows into a level1 node from ICV also flows out of it to its level2
# children. (The earlier version also linked ICV directly to level2 and ran the
# level1/level2 links backwards, which double-counted every level1 node.)
# Stage 1: the root "ICV" node splits into the level1 regions.
data1 = subjectData.groupby(['icv', 'level1'])['comp'].sum().reset_index()
data1.columns = ['source', 'target', 'value']
# Stage 2: each level1 region splits into its level2 sub-regions.
datat2 = subjectData.groupby(['level1', 'level2'])['comp'].sum().reset_index()
datat2.columns = ['source', 'target', 'value']
# Nodes are identified by label, so a level2 name identical to a level1 name
# would merge into that level1 node and form a self-loop; tag such names.
clash = datat2['target'].isin(set(data1['target']))
datat2.loc[clash, 'target'] = datat2.loc[clash, 'target'].astype(str) + ' (level2)'
#combining dataframes
# ignore_index gives comb a fresh 0..n-1 index instead of repeated per-frame labels.
comb = pd.concat([data1, datat2], axis=0, ignore_index=True)
# Collect every label appearing in either link column, in first-seen order.
# ravel('K') reads the 2-D array in its memory order.
endpoint_labels = comb[['source', 'target']].to_numpy().ravel('K')
st_names = list(pd.unique(endpoint_labels))
# Label -> integer node index, numbered by position in st_names.
md = dict(zip(st_names, range(len(st_names))))
# Replace label strings in the link table with their node indices.
for col in ('source', 'target'):
    comb[col] = comb[col].map(md)
# Column-oriented dict copy of the link table.
comb_d = comb.to_dict(orient='list')
# Plot the Sankey diagram: one node per label in st_names, one link per row of comb
# (source/target are node indices, value is the link width).
node_spec = dict(
    label=st_names,
    color="blue",
    pad=5,
    thickness=5,
    line=dict(width=5, color="blue"),
)
link_spec = dict(
    source=comb['source'],
    target=comb['target'],
    value=comb['value'],
)
fig = go.Figure(data=go.Sankey(node=node_spec, link=link_spec))
fig.update_layout(font_size=11, title_text="Sankey Diagram")
fig.show()